{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "NMNIST.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "lRuXm7d4UpIT"
      },
      "source": [
        "# Use the %pip magic (not !pip) so the package is installed into the\n",
        "# environment of the running kernel.\n",
        "# NOTE(review): a later cell imports snntorch.spikevision -- confirm the\n",
        "# installed snntorch release still provides it, and pin that version here.\n",
        "%pip install -q snntorch"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qtbidebHU1PP"
      },
      "source": [
        "import snntorch as snn\n",
        "from snntorch.spikevision import spikedata "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yr8wE2ybU4RQ"
      },
      "source": [
        "## Download Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rhskczZfU3Wo"
      },
      "source": [
        "# Note: NMNIST already applies a default transform, so none is passed here.\n",
        "\n",
        "# Shared settings, defined once so the train and test splits cannot drift apart.\n",
        "DATA_ROOT = \"data/nmnist\"\n",
        "NUM_STEPS = 300  # forwarded as num_steps\n",
        "DT_US = 1000  # forwarded as dt: the number of microseconds integrated\n",
        "\n",
        "train_ds = spikedata.NMNIST(DATA_ROOT, train=True, num_steps=NUM_STEPS, dt=DT_US)\n",
        "test_ds = spikedata.NMNIST(DATA_ROOT, train=False, num_steps=NUM_STEPS, dt=DT_US)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IUoggS_uVAB3"
      },
      "source": [
        "# Show the training dataset object as this cell's output.\n",
        "train_ds"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KKPN98YwVD4V"
      },
      "source": [
        "## Create DataLoader"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aC1cno26VEu8"
      },
      "source": [
        "from torch.utils.data import DataLoader\n",
        "\n",
        "BATCH_SIZE = 64\n",
        "\n",
        "# Only the training loader shuffles; the test loader keeps the dataset order.\n",
        "train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)\n",
        "test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "B535MTroVFMO"
      },
      "source": [
        "# Get a feel for the data: take the first item of the training set and\n",
        "# report the size of its first element (the sample data).\n",
        "first_item = train_dl.dataset[0]\n",
        "first_item[0].size()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-n3lDA8PVKMv"
      },
      "source": [
        "## Visualize"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-zNTZOC3VIKg"
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import snntorch.spikeplot as splt\n",
        "from IPython.display import HTML\n",
        "\n",
        "sample_idx = 20000  # arbitrary fixed index of the sample to visualize\n",
        "\n",
        "# Fetch the sample once instead of re-indexing the dataset for every access.\n",
        "sample = train_dl.dataset[sample_idx]\n",
        "events, target = sample[0], sample[1]\n",
        "\n",
        "# Flatten on-spikes and off-spikes (indices 0 and 1 along the second axis)\n",
        "# into one channel.\n",
        "spike_frames = events[:, 0] + events[:, 1]\n",
        "print(f\"The target is: {target}\")\n",
        "\n",
        "# Animate the summed spikes and embed the result as an HTML5 video.\n",
        "fig, ax = plt.subplots()\n",
        "anim = splt.animator(spike_frames, fig, ax, interval=30)\n",
        "HTML(anim.to_html5_video())\n",
        "\n",
        "# Optional: write the animation to disk instead.\n",
        "# anim.save('nmnist_animation.mp4', writer = 'ffmpeg', fps=50)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}